"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.

Genarris Output Processing and Initial Deduplication Module

This module handles the complete processing pipeline for raw Genarris structure generation
outputs, converting them into standardized formats suitable for ML-based relaxation and
analysis. It implements the critical bridge between structure generation and optimization.

Key Features:
- Efficient parsing of Genarris JSON output files with structure and energy data
- Conversion to standardized Parquet format for high-performance data access
- Initial structure deduplication to remove obvious duplicates before expensive ML relaxation
- Distributed processing support with partitioning for parallel execution

Processing Pipeline:
1. Raw Data Parsing: Extract structures and metadata from Genarris JSON files
2. Format Conversion: Convert to CIF strings and structured data records
3. Initial Filtering: Remove structures with obvious defects or unrealistic properties
4. Pre-relaxation Deduplication: Eliminate duplicate structures using crystallographic comparison
5. Partitioning: Distribute structures across partitions for efficient parallel processing
6. Parquet Export: Save processed data in optimized columnar format
"""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any

import pandas as pd
from ase.io.jsonio import decode
from fairchem.applications.fastcsp.core.utils.deduplicate import deduplicate_structures
from fairchem.applications.fastcsp.core.utils.logging import get_central_logger
from fairchem.applications.fastcsp.core.utils.slurm import (
    get_process_slurm_config,
    submit_slurm_jobs,
)
from fairchem.applications.fastcsp.core.utils.structure import get_partition_id
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm

if TYPE_CHECKING:
    from pathlib import Path


def get_pre_relax_filter_config(config: dict[str, Any]) -> dict[str, Any]:
    """
    Extract pre-relaxation filtering parameters from the workflow configuration.

    Reads the optional ``pre_relaxation_filter`` section of the workflow config
    and fills in defaults for any parameter the user did not specify. These
    parameters drive the initial deduplication of Genarris-generated structures
    before the expensive ML relaxation step.

    Args:
        config: Complete workflow configuration dictionary.

    Returns:
        dict: Pre-relaxation filtering parameters:
            - ltol: Lattice parameter tolerance for structure matching (default: 0.2)
            - stol: Site position tolerance for structure matching (default: 0.3)
            - angle_tol: Lattice angle tolerance in degrees (default: 5)
            - npartitions: Number of partitions for parallel processing (default: 1)

    Notes:
        - Stricter tolerances (lower ltol/stol) preserve more unique structures.
        - Higher npartitions improves parallelization for large datasets.
        - Values should stay consistent with post-relaxation filtering settings.
    """
    # Default values applied for any key absent from the user configuration.
    defaults: dict[str, Any] = {
        "ltol": 0.2,  # lattice tolerance
        "stol": 0.3,  # site tolerance
        "angle_tol": 5,  # angle tolerance in degrees
        "npartitions": 1,  # number of partitions
    }
    user_cfg = config.get("pre_relaxation_filter", {})
    return {key: user_cfg.get(key, fallback) for key, fallback in defaults.items()}


def structure_to_row(
    hash_id: str, struct_dict: dict, mol_id: str, z_val: int, npartitions: int = 1000
) -> dict:
    """
    Convert one Genarris structure record into a standardized DataFrame row.

    Decodes an ASE JSON structure dictionary into an Atoms object, converts it
    to a pymatgen Structure, and extracts the metadata needed downstream
    (composition, size, volume, CIF string, partition assignment).

    Args:
        hash_id: Raw structure hash from the Genarris output file.
        struct_dict: Structure data in ASE JSON-serializable dict form.
        mol_id: Molecule identifier.
        z_val: Number of formula units per unit cell.
        npartitions: Number of partitions used for distributed processing
            (default: 1000).

    Returns:
        dict: Row with keys:
            - mol_id: Molecule identifier
            - z: Number of formula units per unit cell
            - structure_id: Unique identifier ("{hash_id}_{mol_id}_{z_val}")
            - formula: Reduced chemical formula
            - n_atoms: Total number of atoms in the unit cell
            - volume: Unit cell volume (Å³)
            - cif: Crystal structure in CIF format
            - partition_id: Partition assignment for parallel processing
            - structure: pymatgen Structure object for later analysis
    """
    # ASE's decode expects a JSON string, so round-trip the dict through dumps.
    atoms = decode(json.dumps(struct_dict))
    pmg_structure = AseAtomsAdaptor.get_structure(atoms)

    # Combine hash, molecule id and Z into a globally unique structure id;
    # the partition id is derived from it so assignment is deterministic.
    structure_id = f"{hash_id}_{mol_id}_{z_val}"

    return {
        "mol_id": mol_id,
        "z": z_val,
        "structure_id": structure_id,
        "formula": pmg_structure.composition.reduced_formula,
        "n_atoms": len(pmg_structure),
        "volume": pmg_structure.volume,
        "cif": pmg_structure.to(fmt="cif"),
        "partition_id": get_partition_id(structure_id, npartitions),
        "structure": pmg_structure,
    }


def process_genarris_outputs_single(
    base_dir: Path,
    output_dir: Path,
    ltol: float = 0.2,
    stol: float = 0.3,
    angle_tol: float = 5,
    npartitions: int = 1000,
):
    """
    Process Genarris output files from a single molecular conformer directory.

    Converts raw Genarris JSON structure files into standardized parquet format
    with structure deduplication and metadata extraction. This function handles
    the complex directory structure of Genarris outputs and transforms them into
    a format suitable for downstream ML processing.

    Args:
        base_dir: Root directory containing Genarris output structure
                 Expected structure: mol_id/conf_id/z_val/symm_rigid_press/structures.json
        output_dir: Directory where processed parquet files will be saved
        ltol: Lattice parameter tolerance for structure deduplication (default: 0.2)
        stol: Site tolerance for structure deduplication (default: 0.3)
        angle_tol: Angle tolerance for structure deduplication (default: 5°)
        npartitions: Number of partitions for distributed processing (default: 1000)

    Processing Workflow:
        1. Scan directory structure for structures.json files
        2. Extract mol_id and Z values from directory hierarchy
        3. Parse JSON structure data and convert to standardized format
        4. Apply deduplication using pymatgen
        5. Save results in partitioned parquet format for efficient access

    Output Format:
        Creates parquet files partitioned by partition_id containing:
        - structure_id: Unique identifier for each structure
        - mol_id: Original molecule identifier
        - z: Number of formula units per unit cell
        - formula: Reduced chemical formula
        - n_atoms: Total atoms in unit cell
        - volume: Unit cell volume
        - cif: Structure in CIF format
        - group_index: Deduplication group assignment
    """
    logger = get_central_logger()
    logger.info(f"Processing {base_dir}")
    json_files = list(base_dir.glob("**/symm_rigid_press/structures.json"))
    logger.info(f"Found {len(json_files)} files / {base_dir}")
    all_rows = []

    for file_path in tqdm(json_files, desc="Processing files"):
        try:
            # Path layout: .../mol_id/conf_id/z_val/symm_rigid_press/structures.json
            # so z is 2 levels up and mol_id is 4 levels up from the file.
            z_val = int(file_path.parents[2].name)
            mol_id = file_path.parents[4].name
        except Exception as e:
            logger.warning(f"Failed to extract mol_id or z from path {file_path}: {e}")
            continue

        with file_path.open("r") as f:
            struct_data = json.load(f)

        for hash_id, struct_dict in tqdm(
            struct_data.items(),
            desc="Processing structures",
            total=len(struct_data),
        ):
            try:
                row = structure_to_row(hash_id, struct_dict, mol_id, z_val, npartitions)
                all_rows.append(row)
            except Exception as e:
                logger.warning(
                    f"Failed to parse structure {hash_id} in {file_path}: {e}"
                )

    # Guard against an empty result set: dropping the "structure" column from a
    # column-less DataFrame raises KeyError and to_parquet cannot partition an
    # empty frame, so bail out early with a warning instead of crashing.
    if not all_rows:
        logger.warning(f"No structures parsed from {base_dir}; nothing written")
        return

    structures_df = pd.DataFrame(all_rows)
    structures_df = deduplicate_structures(structures_df, ltol, stol, angle_tol)
    # The pymatgen Structure objects are only needed for deduplication and are
    # not parquet-serializable, so drop them before writing.
    structures_df = structures_df.drop(columns=["structure"])
    structures_df.to_parquet(
        output_dir,
        compression="zstd",
        partition_cols=["partition_id"],
    )
    # Log the post-deduplication row count (what was actually written), not the
    # raw parsed count.
    logger.info(f"Saved {len(structures_df)} structures to {output_dir}")


def process_genarris_outputs(
    input_dir: Path,
    output_dir: Path,
    pre_relax_config: dict[str, Any],
    ltol: float = 0.2,
    stol: float = 0.3,
    angle_tol: float = 5,
    npartitions: int = 1000,
):
    """
    Batch process multiple Genarris output directories using SLURM parallel execution.

    Walks the two-level input hierarchy (molecule directory / conformer
    directory), skips conformers whose parquet output already exists, and
    submits one SLURM job per remaining conformer directory.

    Args:
        input_dir: Root directory containing multiple molecule directories
        output_dir: Output directory where processed results will be saved
        pre_relax_config: Configuration dictionary containing SLURM and processing parameters
        ltol: Lattice parameter tolerance for structure deduplication
        stol: Site tolerance for structure deduplication
        angle_tol: Angle tolerance for structure deduplication
        npartitions: Number of partitions for distributed processing

    Returns:
        List of submitit job objects for monitoring execution status
    """
    logger = get_central_logger()

    # Get SLURM configuration
    slurm_params = get_process_slurm_config(pre_relax_config)

    job_args = []
    for mol_dir in input_dir.iterdir():
        # Skip stray files (e.g. logs, .DS_Store): iterdir() on a file raises
        # NotADirectoryError and non-directories cannot hold conformer data.
        if not mol_dir.is_dir():
            continue
        for conf_dir in mol_dir.iterdir():
            if not conf_dir.is_dir():
                continue
            processed_dir = output_dir / mol_dir.name / conf_dir.name

            # Resumability: skip conformers whose partitioned parquet output
            # already exists from a previous run.
            if (
                processed_dir.exists()
                and len(list(processed_dir.glob("*/*.parquet"))) > 0
            ):
                logger.info(
                    f"Skipping {conf_dir} because {processed_dir} already exists"
                )
                continue

            # (callable, positional args, keyword args) triple consumed by
            # submit_slurm_jobs; positional order matches
            # process_genarris_outputs_single's signature.
            job_args.append(
                (
                    process_genarris_outputs_single,
                    (conf_dir, processed_dir, ltol, stol, angle_tol, npartitions),
                    {},
                )
            )

    return submit_slurm_jobs(
        job_args,
        output_dir=output_dir / "slurm",
        **slurm_params,
    )
